{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# python notebook for Make Your Own Neural Network\n",
    "# code for a 3-layer neural network, and code for learning the MNIST dataset\n",
    "# this version trains using the MNIST dataset, then tests on our own images\n",
    "# (c) Tariq Rashid, 2016\n",
    "# license is GPLv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "# scipy.special for the sigmoid function expit()\n",
    "import scipy.special\n",
    "# library for plotting arrays\n",
    "import matplotlib.pyplot\n",
    "# ensure the plots are inside this notebook, not an external window\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper to load data from PNG image files\n",
    "import imageio\n",
    "# glob helps select multiple files using patterns\n",
    "import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# neural network class definition\n",
    "class neuralNetwork:\n",
    "    \"\"\"Three-layer (input, hidden, output) network using the sigmoid activation.\n",
    "\n",
    "    Instances hold the layer sizes (inodes, hnodes, onodes), the link weight\n",
    "    matrices (wih, who), the learning rate (lr) and activation_function.\n",
    "    \"\"\"\n",
    "\n",
    "    # initialise the neural network\n",
    "    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):\n",
    "        \"\"\"Store the layer sizes and learning rate and draw random initial weights.\n",
    "\n",
    "        Args:\n",
    "            inputnodes: number of nodes in the input layer.\n",
    "            hiddennodes: number of nodes in the hidden layer.\n",
    "            outputnodes: number of nodes in the output layer.\n",
    "            learningrate: factor applied to every weight update in train().\n",
    "        \"\"\"\n",
    "        # set number of nodes in each input, hidden, output layer\n",
    "        self.inodes = inputnodes\n",
    "        self.hnodes = hiddennodes\n",
    "        self.onodes = outputnodes\n",
    "\n",
    "        # link weight matrices, wih and who\n",
    "        # weights inside the arrays are w_i_j, where link is from node i to node j in the next layer\n",
    "        # w11 w21\n",
    "        # w12 w22 etc\n",
    "        # wih has shape (hnodes, inodes) and who has shape (onodes, hnodes); values are drawn\n",
    "        # from a normal distribution centred on 0.0 with standard deviation\n",
    "        # 1/sqrt(number of nodes in the layer feeding the links)\n",
    "        self.wih = numpy.random.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))\n",
    "        self.who = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))\n",
    "\n",
    "        # learning rate\n",
    "        self.lr = learningrate\n",
    "\n",
    "        # activation function is the sigmoid function\n",
    "        self.activation_function = scipy.special.expit\n",
    "\n",
    "    # forward pass shared by train() and query()\n",
    "    def _forward(self, inputs):\n",
    "        \"\"\"Propagate already-transposed 2d inputs through both layers.\n",
    "\n",
    "        Returns:\n",
    "            Tuple (hidden_outputs, final_outputs) of activated signals.\n",
    "        \"\"\"\n",
    "        # signals into, then emerging from, the hidden layer\n",
    "        hidden_outputs = self.activation_function(numpy.dot(self.wih, inputs))\n",
    "        # signals into, then emerging from, the final output layer\n",
    "        final_outputs = self.activation_function(numpy.dot(self.who, hidden_outputs))\n",
    "        return hidden_outputs, final_outputs\n",
    "\n",
    "    # train the neural network\n",
    "    def train(self, inputs_list, targets_list):\n",
    "        \"\"\"Adjust wih and who in place from one set of inputs and targets.\n",
    "\n",
    "        Args:\n",
    "            inputs_list: input values; converted with numpy.array(..., ndmin=2).T,\n",
    "                so a flat sequence becomes a column vector.\n",
    "            targets_list: desired output values, converted the same way.\n",
    "        \"\"\"\n",
    "        # convert inputs list to 2d array\n",
    "        inputs = numpy.array(inputs_list, ndmin=2).T\n",
    "        targets = numpy.array(targets_list, ndmin=2).T\n",
    "\n",
    "        hidden_outputs, final_outputs = self._forward(inputs)\n",
    "\n",
    "        # output layer error is the (target - actual)\n",
    "        output_errors = targets - final_outputs\n",
    "        # hidden layer error is the output_errors, split by weights, recombined at hidden nodes\n",
    "        # (computed before who is modified below, so it uses the weights that produced the outputs)\n",
    "        hidden_errors = numpy.dot(self.who.T, output_errors)\n",
    "\n",
    "        # error scaled by output * (1 - output) for each layer\n",
    "        output_delta = output_errors * final_outputs * (1.0 - final_outputs)\n",
    "        hidden_delta = hidden_errors * hidden_outputs * (1.0 - hidden_outputs)\n",
    "\n",
    "        # update the weights for the links between the hidden and output layers\n",
    "        self.who += self.lr * numpy.dot(output_delta, numpy.transpose(hidden_outputs))\n",
    "\n",
    "        # update the weights for the links between the input and hidden layers\n",
    "        self.wih += self.lr * numpy.dot(hidden_delta, numpy.transpose(inputs))\n",
    "\n",
    "    # query the neural network\n",
    "    def query(self, inputs_list):\n",
    "        \"\"\"Return the output-layer signals for inputs_list without changing any weights.\n",
    "\n",
    "        Args:\n",
    "            inputs_list: input values; converted with numpy.array(..., ndmin=2).T.\n",
    "\n",
    "        Returns:\n",
    "            2d numpy array of final outputs (one row per output node).\n",
    "        \"\"\"\n",
    "        # convert inputs list to 2d array\n",
    "        inputs = numpy.array(inputs_list, ndmin=2).T\n",
    "        return self._forward(inputs)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of input, hidden and output nodes\n",
    "# (784 inputs matches the 28x28 images flattened to 784 values later in this notebook;\n",
    "#  the 10 outputs are indexed by the integer label of each training record)\n",
    "input_nodes = 784\n",
    "hidden_nodes = 200\n",
    "output_nodes = 10\n",
    "\n",
    "# learning rate\n",
    "learning_rate = 0.1\n",
    "\n",
    "# seed numpy's global random generator so the randomly drawn initial weights\n",
    "# are the same on every Restart & Run All\n",
    "random_seed = 42\n",
    "numpy.random.seed(random_seed)\n",
    "\n",
    "# create instance of neural network\n",
    "n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the mnist training data CSV file into a list (one string per line)\n",
    "# the with-block closes the file even if reading raises an exception\n",
    "with open(\"mnist_dataset/mnist_train.csv\", 'r') as training_data_file:\n",
    "    training_data_list = training_data_file.readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train the neural network\n",
    "\n",
    "# epochs is the number of times the training data set is used for training\n",
    "epochs = 10\n",
    "\n",
    "for e in range(epochs):\n",
    "    # go through all records in the training data set\n",
    "    for record in training_data_list:\n",
    "        # skip blank lines (e.g. an empty last line in the CSV file), which have no label to parse\n",
    "        if not record.strip():\n",
    "            continue\n",
    "        # split the record by the ',' commas\n",
    "        all_values = record.split(',')\n",
    "        # scale and shift the inputs (a raw value of 0 maps to 0.01, 255 maps to 1.0)\n",
    "        # numpy.asfarray was removed in NumPy 2.0; asarray with dtype=float is the equivalent call\n",
    "        inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01\n",
    "        # create the target output values (all 0.01, except the desired label which is 0.99)\n",
    "        targets = numpy.zeros(output_nodes) + 0.01\n",
    "        # all_values[0] is the target label for this record\n",
    "        targets[int(all_values[0])] = 0.99\n",
    "        n.train(inputs, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading ...  my_own_images/2828_my_own_6.png\n",
      "0.01\n",
      "1.0\n",
      "loading ...  my_own_images/2828_my_own_5.png\n",
      "0.01\n",
      "0.86800003\n",
      "loading ...  my_own_images/2828_my_own_4.png\n",
      "0.01\n",
      "0.93011767\n",
      "loading ...  my_own_images/2828_my_own_2.png\n",
      "0.01\n",
      "1.0\n",
      "loading ...  my_own_images/2828_my_own_3.png\n",
      "0.01\n",
      "1.0\n"
     ]
    }
   ],
   "source": [
    "# our own image test data set\n",
    "our_own_dataset = []\n",
    "\n",
    "# load the png image data as test data set\n",
    "# glob.glob returns matches in arbitrary order, so sort them: the position of each\n",
    "# record in our_own_dataset is then the same on every run and every machine\n",
    "for image_file_name in sorted(glob.glob('my_own_images/2828_my_own_?.png')):\n",
    "\n",
    "    # use the filename to set the correct label (the single character before '.png')\n",
    "    label = int(image_file_name[-5:-4])\n",
    "\n",
    "    # load image data from png files into an array\n",
    "    # NOTE(review): newer imageio releases may no longer accept as_gray - confirm against the installed version\n",
    "    print (\"loading ... \", image_file_name)\n",
    "    img_array = imageio.imread(image_file_name, as_gray=True)\n",
    "\n",
    "    # reshape from 28x28 to list of 784 values, invert values\n",
    "    img_data  = 255.0 - img_array.reshape(784)\n",
    "\n",
    "    # then scale data to range from 0.01 to 1.0\n",
    "    img_data = (img_data / 255.0 * 0.99) + 0.01\n",
    "    print(numpy.min(img_data))\n",
    "    print(numpy.max(img_data))\n",
    "\n",
    "    # append label and image data to test data set (label first, then the 784 values)\n",
    "    record = numpy.append(label,img_data)\n",
    "    our_own_dataset.append(record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[2.68263804e-03]\n",
      " [1.58367502e-03]\n",
      " [5.81142444e-02]\n",
      " [4.54185024e-04]\n",
      " [9.87810696e-01]\n",
      " [6.47654066e-03]\n",
      " [1.10899521e-03]\n",
      " [4.44764937e-03]\n",
      " [1.65304935e-04]\n",
      " [1.99217027e-04]]\n",
      "network says  4\n",
      "match!\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADbxJREFUeJzt3W+InOW5x/HflZhETYpkySZZNolbgxyqwkmPY6x6OKQUi5VCjKHSICWFctIXFVrpi+OfF/GFB8LhtGmRYyE5iU2hMS2k1uCftosUPYVaXUWqbTQV3bYx6+5GJX8wIcnmOi/2SVnjPvdMZp4/s17fD4SZea6597kyyW+fmbmfmdvcXQDimVV3AwDqQfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwR1UZU7W7RokQ8MDFS5SyCU4eFhHT582Fq5b0fhN7NbJP1Q0mxJ/+vuW1L3HxgY0NDQUCe7BJDQaDRavm/bT/vNbLak/5H0JUlXSdpgZle1+/MAVKuT1/yrJb3p7m+5+ylJeyStLaYtAGXrJPz9kv4+5fbBbNtHmNkmMxsys6Hx8fEOdgegSJ2Ef7o3FT72+WB33+buDXdv9Pb2drA7AEXqJPwHJS2fcnuZpEOdtQOgKp2E/0VJV5rZp81srqSvStpXTFsAytb2VJ+7nzGzuyT9WpNTfTvd/U+FdQagVB3N87v7U5KeKqgXABXi9F4gKMIPBEX4gaAIPxAU4QeCIvxAUJV+nh/VO3DgQLK+cOHCZJ1Tsj+5OPIDQRF+ICjCDwRF+IGgCD8QFOEHgmKq7xPgyJEjubVbb701OfbZZ58tuh3MEBz5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAo5vk/AXbs2JFbu/3225Nj+/s/tsIaguDIDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBdTTPb2bDko5JmpB0xt0bRTSFjzp58mSyPjg4mFt75JFHim4HnxBFnOTzeXc/XMDPAVAhnvYDQXUafpf0GzN7ycw2FdEQgGp0+rT/Jnc/ZGaLJQ2a2evu/tzUO2S/FDZJ0ooVKzrcHYCidHTkd/dD2eWYpMckrZ7mPtvcveHuDdZ9A7pH2+E3s/lm9qlz1yV9UdJrRTUGoFydPO1fIukxMzv3c3a7+68K6QpA6doOv7u/JemfC+wFOfbs2ZOsz58/P7e2dOnSotu5IKdPn86tTUxMJMdefPHFRbeDKZjqA4Ii/EBQhB8IivADQRF+ICjCDwTFV3fPALt27UrWt2zZUlEnF+7EiRO5tSeffDI5dv369cn63Llz2+oJkzjyA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQzPN3gb179ybrPT09yfr1119fZDuFOnLkSG5t+/btybHr1q0ruh1MwZEfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Jinr8Cx48fT9bvvvvuZP3pp58usp1KjYyM5NZSn/WXpHnz5hXdDqbgyA8ERfiBoAg/EBThB4Ii/EBQhB8IivADQTWd5zeznZK+LGnM3a/JtvVI+pmkAUnDku5w9w/Ka3Nme/jhh5P1G264IVm/+uqri2ynUvv378+t9ff3J8eaWdHtYIpWjvw/lnTLedvukfSMu18p6ZnsNoAZpGn43f05Se+ft3mtpHPLyOySdFvBfQEoWbuv+Ze4+4gkZZeLi2sJQBVKf8PPzDaZ2ZCZDY2Pj5e9OwAtajf8o2bWJ0nZ5VjeHd19m7s33L3R29vb5u4AFK3d8O+TtDG7vlHS48W0A6AqTcNvZo9K+r2kfzKzg2b2DUlbJN1sZn+RdHN2G8AM0nSe39035JS+UHAvM1azz+vv3r07WR8cHCyyna7ywgsv5NYuv/zyCjvB+TjDDwiK8ANBEX4gKMIPBEX4gaAIPxAUX91dgHvvvTdZv/POO5P1qGc+9vX11d1CaBz5gaAIPxAU4QeCIvxAUIQf
CIrwA0ERfiAo5vlb9M477+TW3n777eTYhx56qOh2CjMxMZGsz549O1l392T9ww8/zK1deumlybEoF0d+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKef7MyZMnk/XrrrsutzYyMpIce//99yfro6Ojyfrp06eT9Vmz8n+Hnzp1Kjm203n+ZvU9e/bk1l5//fXk2KNHjybrzc4xWLZsWW5t+fLlybFLly5N1pstL37JJZck66l/s1StSBz5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiCopvP8ZrZT0pcljbn7Ndm2ByT9u6Tx7G73uftTZTVZhWbLbK9Zsya3tnjx4uTYAwcOJOvN5vF7enraHm9mybGXXXZZst5sHn9sbCxZT83Fv/HGG8mxzz//fLL+wQcfJOvvvvtubq3ZY9rs77VixYpkffPmzcl6as2CK664Ijl2zpw5yXqrWjny/1jSLdNs3+ruq7I/Mzr4QERNw+/uz0l6v4JeAFSok9f8d5nZH81sp5ktLKwjAJVoN/w/krRS0ipJI5K+l3dHM9tkZkNmNjQ+Pp53NwAVayv87j7q7hPuflbSdkmrE/fd5u4Nd29EXZAS6EZthd/Mpr5VuU7Sa8W0A6AqrUz1PSppjaRFZnZQ0mZJa8xslSSXNCzpmyX2CKAETcPv7hum2byjhF5qtXBh+j3L3bt3V9RJsc6ePZusd/rZ8dT38kvStddem1t78MEHk2PXr1/fVk9FOHPmTLL+3nvvJesLFixI1lPnT/B5fgClIvxAUIQfCIrwA0ERfiAowg8ExVd3Z5p9dHWmqmraKM+JEydya/Pmzauwkwtz0UXpaCxZsqSiTsrDkR8IivADQRF+ICjCDwRF+IGgCD8QFOEHgmKeHx2ZO3du2/VmX72NcnHkB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgmOdHR5p97v3YsWO5Neb568WRHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCajrPb2bLJf1E0lJJZyVtc/cfmlmPpJ9JGpA0LOkOd2fiNhh3T9aXLVuWW2u2zDXK1cqR/4yk77r7ZyR9TtK3zOwqSfdIesbdr5T0THYbwAzRNPzuPuLuL2fXj0naL6lf0lpJu7K77ZJ0W1lNAijeBb3mN7MBSZ+V9AdJS9x9RJr8BSFpcdHNAShPy+E3swWS9kr6jrsfvYBxm8xsyMyGxsfH2+kRQAlaCr+ZzdFk8H/q7r/INo+aWV9W75M0Nt1Yd9/m7g13b/T29hbRM4ACNA2/mZmkHZL2u/v3p5T2SdqYXd8o6fHi2wNQllY+0nuTpK9JetXMXsm23Sdpi6Sfm9k3JP1N0lfKaRHdbPLYkG/r1q25tSeeeCI5ttk0YrN9I61p+N39d5LyHuUvFNsOgKpwhh8QFOEHgiL8QFCEHwiK8ANBEX4gKL66G6W68cYbc2t9fX3Jsczzl4sjPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExTw/SjVrVv7xZeXKlRV2gvNx5AeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgmobfzJab2W/NbL+Z/cnMvp1tf8DM3jGzV7I/t5bfLoCitPJlHmckfdfdXzazT0l6ycwGs9pWd//v8toDUJam4Xf3EUkj2fVjZrZfUn/ZjQEo1wW95jezAUmflfSHbNNdZvZHM9tpZgtzxmwysyEzGxofH++oWQDFaTn8ZrZA0l5J33H3o5J+JGmlpFWafGbwvenGufs2d2+4e6O3t7eAlgEUoaXwm9kcTQb/p+7+C0ly91F3n3D3s5K2S1pdXpsAitbKu/0maYek/e7+/Snbpy6xuk7Sa8W3B6Asrbzbf5Okr0l61cxeybbdJ2mDma2S5JKGJX2zlA4BlKKVd/t/J2m6hdCfKr4dAFXhDD8gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBTh
B4Ii/EBQ5u7V7cxsXNJfp2xaJOlwZQ1cmG7trVv7kuitXUX2drm7t/R9eZWG/2M7Nxty90ZtDSR0a2/d2pdEb+2qqzee9gNBEX4gqLrDv63m/ad0a2/d2pdEb+2qpbdaX/MDqE/dR34ANakl/GZ2i5m9YWZvmtk9dfSQx8yGzezVbOXhoZp72WlmY2b22pRtPWY2aGZ/yS6nXSatpt66YuXmxMrStT523bbideVP+81stqQDkm6WdFDSi5I2uPufK20kh5kNS2q4e+1zwmb2b5KOS/qJu1+TbfsvSe+7+5bsF+dCd/+PLuntAUnH6165OVtQpm/qytKSbpP0ddX42CX6ukM1PG51HPlXS3rT3d9y91OS9khaW0MfXc/dn5P0/nmb10ralV3fpcn/PJXL6a0ruPuIu7+cXT8m6dzK0rU+dom+alFH+Psl/X3K7YPqriW/XdJvzOwlM9tUdzPTWJItm35u+fTFNfdzvqYrN1fpvJWlu+axa2fF66LVEf7pVv/ppimHm9z9XyR9SdK3sqe3aE1LKzdXZZqVpbtCuyteF62O8B+UtHzK7WWSDtXQx7Tc/VB2OSbpMXXf6sOj5xZJzS7Hau7nH7pp5ebpVpZWFzx23bTidR3hf1HSlWb2aTObK+mrkvbV0MfHmNn87I0Ymdl8SV9U960+vE/Sxuz6RkmP19jLR3TLys15K0ur5seu21a8ruUkn2wq4weSZkva6e7/WXkT0zCzKzR5tJcmFzHdXWdvZvaopDWa/NTXqKTNkn4p6eeSVkj6m6SvuHvlb7zl9LZGk09d/7Fy87nX2BX39q+S/k/Sq5LOZpvv0+Tr69oeu0RfG1TD48YZfkBQnOEHBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiCo/wc6g9+PxlWALQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# test the neural network with our own images\n",
    "\n",
    "# record to test (index into our_own_dataset)\n",
    "item = 2\n",
    "test_record = our_own_dataset[item]\n",
    "\n",
    "# correct answer is first value, data is remaining values\n",
    "correct_label = test_record[0]\n",
    "inputs = test_record[1:]\n",
    "\n",
    "# plot image\n",
    "matplotlib.pyplot.imshow(inputs.reshape(28,28), cmap='Greys', interpolation='None')\n",
    "\n",
    "# query the network\n",
    "outputs = n.query(inputs)\n",
    "print (outputs)\n",
    "\n",
    "# the index of the highest value corresponds to the label\n",
    "label = numpy.argmax(outputs)\n",
    "print(\"network says \", label)\n",
    "# report whether the network's answer agrees with the label taken from the file name\n",
    "print (\"match!\" if label == correct_label else \"no match!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
